# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os
import os.path as osp

import numpy as np

from datumaro.components.annotation import AnnotationType, Label, LabelCategories
from datumaro.components.converter import Converter
from datumaro.components.errors import MediaTypeError
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.media import Image
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


class MnistCsvPath:
    IMAGE_SIZE = 28
    NONE_LABEL = -1


class MnistCsvExtractor(SourceExtractor):
    def __init__(self, path, subset=None):
        if not osp.isfile(path):
            raise FileNotFoundError("Can't read annotation file '%s'" % path)

        if not subset:
            file_name = osp.splitext(osp.basename(path))[0]
            subset = file_name.rsplit("_", maxsplit=1)[-1]

        super().__init__(subset=subset)
        self._dataset_dir = osp.dirname(path)

        self._categories = self._load_categories()

        self._items = list(self._load_items(path).values())

    def _load_categories(self):
        if has_meta_file(self._dataset_dir):
            return {
                AnnotationType.label: LabelCategories.from_iterable(
                    parse_meta_file(self._dataset_dir).keys()
                )
            }

        label_cat = LabelCategories()

        labels_file = osp.join(self._dataset_dir, "labels.txt")
        if osp.isfile(labels_file):
            with open(labels_file, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    label_cat.add(line)
        else:
            for i in range(10):
                label_cat.add(str(i))

        return {AnnotationType.label: label_cat}

    def _load_items(self, path):
        items = {}
        with open(path, "r", encoding="utf-8") as f:
            annotation_table = f.readlines()

        metafile = osp.join(self._dataset_dir, "meta_%s.csv" % self._subset)
        meta = []
        if osp.isfile(metafile):
            with open(metafile, "r", encoding="utf-8") as f:
                meta = f.readlines()

        for i, data in enumerate(annotation_table):
            data = data.split(",")
            item_anno = []
            label = int(data[0])
            if label != MnistCsvPath.NONE_LABEL:
                item_anno.append(Label(label))

            if 0 < len(meta):
                meta[i] = meta[i].strip().split(",")

            # support for single-channel image only
            image = None
            if 1 < len(data):
                if 0 < len(meta) and 1 < len(meta[i]):
                    image = np.array([int(pix) for pix in data[1:]], dtype="uint8").reshape(
                        int(meta[i][-2]), int(meta[i][-1])
                    )
                else:
                    image = np.array([int(pix) for pix in data[1:]], dtype="uint8").reshape(28, 28)

            if image is not None:
                image = Image(data=image)

            if 0 < len(meta) and len(meta[i]) in [1, 3]:
                i = meta[i][0]

            items[i] = DatasetItem(id=i, subset=self._subset, media=image, annotations=item_anno)
        return items


class MnistCsvImporter(Importer):
    @classmethod
    def find_sources(cls, path):
        return cls._find_sources_recursive(
            path, ".csv", "mnist_csv", file_filter=lambda p: not osp.basename(p).startswith("meta")
        )


class MnistCsvConverter(Converter):
    DEFAULT_IMAGE_EXT = ".png"

    def apply(self):
        if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image):
            raise MediaTypeError("Media type is not an image")

        os.makedirs(self._save_dir, exist_ok=True)
        if self._save_dataset_meta:
            self._save_meta_file(self._save_dir)

        for subset_name, subset in self._extractor.subsets().items():
            data = []
            item_ids = {}
            image_sizes = {}
            for item in subset:
                anns = [a.label for a in item.annotations if a.type == AnnotationType.label]
                label = MnistCsvPath.NONE_LABEL
                if anns:
                    label = anns[0]

                if item.media and self._save_media:
                    image = item.media
                    if not image.has_data:
                        data.append([label, None])
                    else:
                        if (
                            image.data.shape[0] != MnistCsvPath.IMAGE_SIZE
                            or image.data.shape[1] != MnistCsvPath.IMAGE_SIZE
                        ):
                            image_sizes[len(data)] = [image.data.shape[0], image.data.shape[1]]
                        image = image.data.reshape(-1).astype(np.uint8).tolist()
                        image.insert(0, label)
                        data.append(image)
                else:
                    data.append([label])

                if item.id != str(len(data) - 1):
                    item_ids[len(data) - 1] = item.id

            anno_file = osp.join(self._save_dir, "mnist_%s.csv" % subset_name)
            self.save_in_csv(anno_file, data)

            # it is't in the original format,
            # this is for storng other names and sizes of images
            if len(item_ids) or len(image_sizes):
                meta = []
                if len(item_ids) and len(image_sizes):
                    # other names and sizes of images
                    size = [MnistCsvPath.IMAGE_SIZE, MnistCsvPath.IMAGE_SIZE]
                    for i in range(len(data)):
                        w, h = image_sizes.get(i, size)
                        meta.append([item_ids.get(i, i), w, h])

                elif len(item_ids):
                    # other names of images
                    for i in range(len(data)):
                        meta.append([item_ids.get(i, i)])

                elif len(image_sizes):
                    # other sizes of images
                    size = [MnistCsvPath.IMAGE_SIZE, MnistCsvPath.IMAGE_SIZE]
                    for i in range(len(data)):
                        meta.append(image_sizes.get(i, size))

                metafile = osp.join(self._save_dir, "meta_%s.csv" % subset_name)
                self.save_in_csv(metafile, meta)

        self.save_labels()

    def save_in_csv(self, path, data):
        with open(path, "w", encoding="utf-8") as f:
            for row in data:
                f.write(",".join([str(p) for p in row]) + "\n")

    def save_labels(self):
        labels_file = osp.join(self._save_dir, "labels.txt")
        with open(labels_file, "w", encoding="utf-8") as f:
            f.writelines(
                l.name + "\n"
                for l in self._extractor.categories().get(AnnotationType.label, LabelCategories())
            )
